In [4]:
# Notebook setup: IPython extensions, imports and global plotting defaults.
# autoreload mode 2 re-imports modules before each cell executes, so edits to
# project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os, sys
# Make the parent directory importable (presumably where the local `merf`
# package lives -- confirm against the repository layout).
sys.path.append('..')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
# Use seaborn's "poster" context for every figure in this notebook.
sns.set_context("poster")
# Load the line_profiler IPython extension.
%load_ext line_profiler
import numpy as np
import pandas as pd
import matplotlib as mpl
# Default figure size for plots that do not pass their own figsize.
mpl.rcParams['figure.figsize'] = (11,8)

In [32]:
from merf.utils import MERFDataGenerator
from merf.merf import MERF

Test Data Generation


In [42]:
# Synthetic-data generator used for every split below.
# sigma_b is passed as sqrt(4.5) and sigma_e as 1; these look like the
# random-effect and noise standard deviations -- confirm against
# MERFDataGenerator's documentation.
dgm = MERFDataGenerator(
    m=0.6,
    sigma_b=np.sqrt(4.5),
    sigma_e=1,
)

In [52]:
# How many clusters to create for each listed size.
num_clusters_each_size = 1

# Sizes requested for the training split, the "known" test split and the
# "new" test split respectively.
train_sizes, known_sizes, new_sizes = [7, 9], [63, 81], [70, 90]

In [53]:
# Build the cluster-size array for each split (train, known, new) with the same
# helper and the same number of clusters per size.
train_cluster_sizes, known_cluster_sizes, new_cluster_sizes = (
    MERFDataGenerator.create_cluster_sizes_array(split_sizes, num_clusters_each_size)
    for split_sizes in (train_sizes, known_sizes, new_sizes)
)

In [54]:
# Sanity check: number of entries in each cluster-size array (train, known, new).
tuple(
    len(cluster_sizes)
    for cluster_sizes in (train_cluster_sizes, known_cluster_sizes, new_cluster_sizes)
)


Out[54]:
(2, 2, 2)

In [55]:
# Seed NumPy's global RNG immediately before sampling so that re-running this
# cell reproduces the same draw.
# NOTE(review): this assumes MERFDataGenerator samples through the global
# np.random state -- confirm in merf/utils.py.
np.random.seed(42)

# Draw the training split and the two test splits in one call; the generator
# also returns the training cluster ids plus the ptev / prev values it logs.
(
    train,
    test_known,
    test_new,
    training_cluster_ids,
    ptev,
    prev,
) = dgm.generate_split_samples(train_cluster_sizes, known_cluster_sizes, new_cluster_sizes)


INFO     [utils.py:164] Drew 320 samples from 4 clusters.
INFO     [utils.py:165] PTEV = 89.20359622856809, PREV = 54.46396672949858.

In [56]:
# Row counts of the train, known-cluster test and new-cluster test frames.
tuple(len(split_frame) for split_frame in (train, test_known, test_new))


Out[56]:
(16, 144, 160)

In [57]:
# Peek at the first five training rows (target, fixed-effect features, Z, cluster).
train.head(n=5)


Out[57]:
y X_0 X_1 X_2 Z cluster
0 5.048030 0.403977 -0.642945 0.465257 1.0 0
1 6.265467 1.404621 -0.788585 1.625847 1.0 0
2 6.219084 1.879838 0.516437 -0.472274 1.0 0
3 4.223468 0.607317 0.646003 -0.740880 1.0 0
4 1.244755 -0.890759 -0.842052 -0.542031 1.0 0

MERF Training


In [58]:
# Split the training frame into the four inputs handed to MERF.fit below.
X_train = train.loc[:, ['X_0', 'X_1', 'X_2']]   # feature columns (DataFrame)
Z_train = train.loc[:, ['Z']]                   # kept 2-D: a one-column DataFrame
clusters_train = train.loc[:, 'cluster']        # cluster id per row (Series)
y_train = train.loc[:, 'y']                     # target (Series)

In [59]:
# Configure the model, then fit it on the training split. The fit call is the
# cell's last expression, so its return value is shown as the cell output.
mrf = MERF(
    n_estimators=300,
    max_iterations=100,
)
mrf.fit(X_train, Z_train, clusters_train, y_train)


INFO     [merf.py:235] GLL is 39.12237319347477 at iteration 1.
INFO     [merf.py:235] GLL is 38.70673341248194 at iteration 2.
INFO     [merf.py:235] GLL is 38.89880139438688 at iteration 3.
INFO     [merf.py:235] GLL is 38.900415164834875 at iteration 4.
INFO     [merf.py:235] GLL is 39.5152131695921 at iteration 5.
INFO     [merf.py:235] GLL is 39.68903766568897 at iteration 6.
INFO     [merf.py:235] GLL is 39.61351500837149 at iteration 7.
INFO     [merf.py:235] GLL is 39.901516494169 at iteration 8.
INFO     [merf.py:235] GLL is 38.88538452334219 at iteration 9.
INFO     [merf.py:235] GLL is 39.909458240100214 at iteration 10.
INFO     [merf.py:235] GLL is 39.887637024493884 at iteration 11.
INFO     [merf.py:235] GLL is 40.02736317608179 at iteration 12.
INFO     [merf.py:235] GLL is 40.424389526373616 at iteration 13.
INFO     [merf.py:235] GLL is 40.46047064163921 at iteration 14.
INFO     [merf.py:235] GLL is 39.89740797734385 at iteration 15.
INFO     [merf.py:235] GLL is 40.30870867601432 at iteration 16.
INFO     [merf.py:235] GLL is 40.29246228715476 at iteration 17.
INFO     [merf.py:235] GLL is 39.55518871342818 at iteration 18.
INFO     [merf.py:235] GLL is 39.93827936286138 at iteration 19.
INFO     [merf.py:235] GLL is 40.4103839246995 at iteration 20.
INFO     [merf.py:235] GLL is 40.22527109561858 at iteration 21.
INFO     [merf.py:235] GLL is 40.603911018665684 at iteration 22.
INFO     [merf.py:235] GLL is 40.971267992399575 at iteration 23.
INFO     [merf.py:235] GLL is 40.264924978905455 at iteration 24.
INFO     [merf.py:235] GLL is 40.11784549135437 at iteration 25.
INFO     [merf.py:235] GLL is 40.359625891450904 at iteration 26.
INFO     [merf.py:235] GLL is 40.64814596445594 at iteration 27.
INFO     [merf.py:235] GLL is 40.48479820129622 at iteration 28.
INFO     [merf.py:235] GLL is 40.95167919715522 at iteration 29.
INFO     [merf.py:235] GLL is 41.26522235710543 at iteration 30.
INFO     [merf.py:235] GLL is 41.51087989726108 at iteration 31.
INFO     [merf.py:235] GLL is 41.121740877458556 at iteration 32.
INFO     [merf.py:235] GLL is 41.41150527860117 at iteration 33.
INFO     [merf.py:235] GLL is 40.5826366990965 at iteration 34.
INFO     [merf.py:235] GLL is 40.80631781820975 at iteration 35.
INFO     [merf.py:235] GLL is 40.70738821506871 at iteration 36.
INFO     [merf.py:235] GLL is 41.52245114189566 at iteration 37.
INFO     [merf.py:235] GLL is 41.13298088751867 at iteration 38.
INFO     [merf.py:235] GLL is 41.16197727896002 at iteration 39.
INFO     [merf.py:235] GLL is 41.63037051047222 at iteration 40.
INFO     [merf.py:235] GLL is 41.55721269658952 at iteration 41.
INFO     [merf.py:235] GLL is 41.9044407366373 at iteration 42.
INFO     [merf.py:235] GLL is 41.5545778031117 at iteration 43.
INFO     [merf.py:235] GLL is 41.74861130713537 at iteration 44.
INFO     [merf.py:235] GLL is 42.90331335470315 at iteration 45.
INFO     [merf.py:235] GLL is 42.549126757089674 at iteration 46.
INFO     [merf.py:235] GLL is 42.381724153060816 at iteration 47.
INFO     [merf.py:235] GLL is 41.5903538749287 at iteration 48.
INFO     [merf.py:235] GLL is 43.18870758879184 at iteration 49.
INFO     [merf.py:235] GLL is 42.53916797084639 at iteration 50.
INFO     [merf.py:235] GLL is 42.15854699595499 at iteration 51.
INFO     [merf.py:235] GLL is 42.37649673906907 at iteration 52.
INFO     [merf.py:235] GLL is 42.76475704826963 at iteration 53.
INFO     [merf.py:235] GLL is 42.63554405262547 at iteration 54.
INFO     [merf.py:235] GLL is 43.03815597641354 at iteration 55.
INFO     [merf.py:235] GLL is 43.047903585732016 at iteration 56.
INFO     [merf.py:235] GLL is 43.39015426883333 at iteration 57.
INFO     [merf.py:235] GLL is 42.733955555456916 at iteration 58.
INFO     [merf.py:235] GLL is 41.81839477190927 at iteration 59.
INFO     [merf.py:235] GLL is 42.505641082914366 at iteration 60.
INFO     [merf.py:235] GLL is 43.63658324737081 at iteration 61.
INFO     [merf.py:235] GLL is 43.89011044441564 at iteration 62.
INFO     [merf.py:235] GLL is 43.5019399493946 at iteration 63.
INFO     [merf.py:235] GLL is 44.21761487764755 at iteration 64.
INFO     [merf.py:235] GLL is 43.91855690485879 at iteration 65.
INFO     [merf.py:235] GLL is 44.304163180756674 at iteration 66.
INFO     [merf.py:235] GLL is 44.92012528736535 at iteration 67.
INFO     [merf.py:235] GLL is 43.936618530991886 at iteration 68.
INFO     [merf.py:235] GLL is 44.54380538933578 at iteration 69.
INFO     [merf.py:235] GLL is 44.05666675776545 at iteration 70.
INFO     [merf.py:235] GLL is 44.50354187563584 at iteration 71.
INFO     [merf.py:235] GLL is 43.823972712507924 at iteration 72.
INFO     [merf.py:235] GLL is 44.57100235011918 at iteration 73.
INFO     [merf.py:235] GLL is 44.36244583329746 at iteration 74.
INFO     [merf.py:235] GLL is 45.355734578999055 at iteration 75.
INFO     [merf.py:235] GLL is 45.078166058759436 at iteration 76.
INFO     [merf.py:235] GLL is 45.56793346558064 at iteration 77.
INFO     [merf.py:235] GLL is 45.739807057822496 at iteration 78.
INFO     [merf.py:235] GLL is 45.43676977555185 at iteration 79.
INFO     [merf.py:235] GLL is 44.91755864451537 at iteration 80.
INFO     [merf.py:235] GLL is 44.881189755365014 at iteration 81.
INFO     [merf.py:235] GLL is 44.99503138614905 at iteration 82.
INFO     [merf.py:235] GLL is 45.394900444404854 at iteration 83.
INFO     [merf.py:235] GLL is 45.38803631060979 at iteration 84.
INFO     [merf.py:235] GLL is 44.89780513671174 at iteration 85.
INFO     [merf.py:235] GLL is 46.173624382352855 at iteration 86.
INFO     [merf.py:235] GLL is 45.98336106659436 at iteration 87.
INFO     [merf.py:235] GLL is 45.66343249701144 at iteration 88.
INFO     [merf.py:235] GLL is 45.71929053899903 at iteration 89.
INFO     [merf.py:235] GLL is 45.993396026071494 at iteration 90.
INFO     [merf.py:235] GLL is 46.11272384949217 at iteration 91.
INFO     [merf.py:235] GLL is 46.14000624149306 at iteration 92.
INFO     [merf.py:235] GLL is 46.129454519044266 at iteration 93.
INFO     [merf.py:235] GLL is 45.86386663027557 at iteration 94.
INFO     [merf.py:235] GLL is 46.10786158219537 at iteration 95.
INFO     [merf.py:235] GLL is 45.633466639809484 at iteration 96.
INFO     [merf.py:235] GLL is 46.46772491618914 at iteration 97.
INFO     [merf.py:235] GLL is 46.64055317420244 at iteration 98.
INFO     [merf.py:235] GLL is 47.05906984608626 at iteration 99.
INFO     [merf.py:235] GLL is 46.565940234228535 at iteration 100.
Out[59]:
<merf.merf.MERF at 0x1157770f0>

In [60]:
# Training diagnostics: one panel per quantity tracked across MERF iterations.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(ax_gll, ax_sigma_b2), (ax_sigma_e2, ax_b_hat) = axes

# GLL value recorded at each iteration.
ax_gll.plot(mrf.gll_history)
ax_gll.set_ylabel('GLL')

# Top-left element of each stored D_hat.
D_hat_history = [x[0][0] for x in mrf.D_hat_history]
ax_sigma_b2.plot(D_hat_history)
ax_sigma_b2.set_ylabel('sigma_b2_hat')

ax_sigma_e2.plot(mrf.sigma2_hat_history)
ax_sigma_e2.set_ylabel('sigma_e2_hat')

# One column per stored b_hat snapshot. Number the columns from the actual
# history length rather than a hard-coded 101, so this cell keeps working when
# max_iterations changes.
b_df = pd.concat(mrf.b_hat_history, axis=1)
b_df.columns = range(b_df.shape[1])
ax_b_hat.plot(b_df.loc[0])
ax_b_hat.plot(b_df.loc[1])
ax_b_hat.set_ylabel('b_hat')

# Shared cosmetics; grid() takes a boolean, not the string 'on'.
for ax in axes.flat:
    ax.grid(True)
    ax.set_xlabel('Iteration')


Out[60]:
Text(0.5,0,'Iteration')

In [61]:
# Histogram of the fitted b_i values stored on the trained model.
mrf.trained_b.hist(bins=15)

# Label the axes the histogram was just drawn on.
hist_ax = plt.gca()
hist_ax.set_xlabel('b_i')
hist_ax.set_title('Distribution of b_is')


Out[61]:
Text(0.5,1,'Distribution of b_is')

In [ ]: